package cn.initcap.algorithm.graph.util;

import cn.initcap.algorithm.graph.Edge;
import cn.initcap.algorithm.graph.Graph;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * 使用BellmanFord算法求最短路径(通过松弛操作，处理负权边)，这是一个有向图算法，对无向图（没有负权边）是成立的
 * 对于无向图，有负权边等同于有负权环
 * 对所有边进行V - 1次松弛操作，则求出了到所有点，经过的边数最多为v - 1的最短路
 * 如果对所有边再进行一次松弛操作，还能更新disTo，则说明图中包含负权环
 *
 * @author initcap
 * @date Created in 1/21/19 12:06 PM.
 */
public class BellmanFord<Weight extends Number & Comparable> {

    /**
     * from[i]记录最短路径中, 到达i点的边是哪一条
     */
    Edge[] from;
    /**
     * 可以用来恢复整个最短路径
     * 标记图中是否有负权环
     */
    boolean hasNegativeCycle;
    /**
     * 图的引用
     */
    private Graph graph;
    /**
     * 起始点
     */
    private int s;
    /**
     * distTo[i]存储从起始点s到i的最短路径长度
     */
    private Number[] distTo;

    /**
     * 构造函数, 使用BellmanFord算法求最短路径
     *
     * @param graph
     * @param s
     */
    public BellmanFord(Graph graph, int s) {

        this.graph = graph;
        this.s = s;
        distTo = new Number[this.graph.nodeSize()];
        from = new Edge[this.graph.nodeSize()];
        // 初始化所有的节点s都不可达, 由from数组来表示
        for (int i = 0; i < this.graph.nodeSize(); i++) {
            from[i] = null;
        }

        // 设置distTo[s] = 0, 并且让from[s]不为NULL, 表示初始s节点可达且距离为0
        distTo[s] = 0.0;
        // 这里我们from[s]的内容是new出来的, 注意要在析构函数里delete掉
        from[s] = new Edge<>(s, s, (Weight) (Number) (0.0));

        // Bellman-Ford的过程
        // 进行V-1次循环, 每一次循环求出从起点到其余所有点, 最多使用pass步可到达的最短距离
        for (int pass = 1; pass < this.graph.nodeSize(); pass++) {

            // 每次循环中对所有的边进行一遍松弛操作
            // 遍历所有边的方式是先遍历所有的顶点, 然后遍历和所有顶点相邻的所有边
            for (int i = 0; i < this.graph.nodeSize(); i++) {
                // 使用我们实现的邻边迭代器遍历和所有顶点相邻的所有边
                for (Object item : this.graph.adj(i)) {
                    Edge<Weight> e = (Edge<Weight>) item;
                    // 对于每一个边首先判断e->v()可达
                    // 之后看如果e->w()以前没有到达过， 显然我们可以更新distTo[e->w()]
                    // 或者e->w()以前虽然到达过, 但是通过这个e我们可以获得一个更短的距离, 即可以进行一次松弛操作, 我们也可以更新distTo[e->w()]
                    if (from[e.v()] != null && (from[e.w()] == null || distTo[e.v()].doubleValue() + e.wt().doubleValue() < distTo[e.w()].doubleValue())) {
                        distTo[e.w()] = distTo[e.v()].doubleValue() + e.wt().doubleValue();
                        from[e.w()] = e;
                    }
                }
            }
        }

        hasNegativeCycle = detectNegativeCycle();
    }

    /**
     * 判断图中是否有负权环
     *
     * @return
     */
    boolean detectNegativeCycle() {

        for (int i = 0; i < graph.nodeSize(); i++) {
            for (Object item : graph.adj(i)) {
                Edge<Weight> e = (Edge<Weight>) item;
                if (from[e.v()] != null && distTo[e.v()].doubleValue() + e.wt().doubleValue() < distTo[e.w()].doubleValue()) {
                    return true;
                }
            }
        }

        return false;
    }

    /**
     * 返回图中是否有负权环
     *
     * @return
     */
    boolean negativeCycle() {
        return hasNegativeCycle;
    }

    /**
     * 返回从s点到w点的最短路径长度
     *
     * @param w
     * @return
     */
    Number shortestPathTo(int w) {
        assert w >= 0 && w < graph.nodeSize();
        assert !hasNegativeCycle;
        assert hasPathTo(w);
        return distTo[w];
    }

    /**
     * 判断从s点到w点是否联通
     *
     * @param w
     * @return
     */
    boolean hasPathTo(int w) {
        assert (w >= 0 && w < graph.nodeSize());
        return from[w] != null;
    }

    /**
     * 寻找从s到w的最短路径, 将整个路径经过的边存放在vec中
     *
     * @param w
     * @return
     */
    List<Edge<Weight>> shortestPath(int w) {

        assert w >= 0 && w < graph.nodeSize();
        assert !hasNegativeCycle;
        assert hasPathTo(w);

        // 通过from数组逆向查找到从s到w的路径, 存放到栈中
        Deque<Edge<Weight>> s = new ArrayDeque<>();
        Edge<Weight> e = from[w];
        while (e.v() != this.s) {
            s.push(e);
            e = from[e.v()];
        }
        s.push(e);

        // 从栈中依次取出元素, 获得顺序的从s到w的路径
        List<Edge<Weight>> res = new ArrayList<>();
        while (!s.isEmpty()) {
            e = s.pop();
            res.add(e);
        }

        return res;
    }

    /**
     * 打印出从s点到w点的路径
     *
     * @param w
     */
    void showPath(int w) {

        assert (w >= 0 && w < graph.nodeSize());
        assert (!hasNegativeCycle);
        assert (hasPathTo(w));

        List<Edge<Weight>> res = shortestPath(w);
        for (int i = 0; i < res.size(); i++) {
            System.out.print(res.get(i).v() + " -> ");
            if (i == res.size() - 1) {
                System.out.println(res.get(i).w());
            }
        }
    }

}
